
import argparse
import math


def _effective_expansion_factor(args):
    """Return ``args.expansion_factor``, or 1 when it is not positive.

    Per the ``--expansion_factor`` help text, a value < 0 means each block is
    a plain FC + activation, which is counted here as a factor of 1.
    """
    return args.expansion_factor if args.expansion_factor > 0 else 1


def _configure_unfixed(args):
    """fix == -1: connections are not fixed; derive sizes from dim and dim_token.

    Sets ``max_dim``, ``prod_dim``, ``num_connections`` and ``num_params``.
    """
    assert args.dim_token > 0
    ef = _effective_expansion_factor(args)
    c, s = args.dim, args.dim_token
    args.max_dim = ef * c * s
    args.prod_dim = c * s
    args.num_connections = ef * c * s * (c + s) / 2
    args.num_params = ef * (c**2 + s**2) / 2
    if args.dim_ppfc != c:
        print("Warning: default setting of dim_ppfc is not equal to dim !")


def _fix_num_connections(args):
    """fix == 2: keep ``num_connections`` fixed and solve for ``dim_token``.

    With c = dim, solves  ef * s * c * (s + c) / 2 = num_connections  for s,
    i.e. the positive root of  s**2 + c*s - 2*num_connections/(ef*c) = 0,
    rounded to the nearest integer.
    """
    assert args.num_connections > 0
    nc = args.num_connections
    ef = _effective_expansion_factor(args)
    c = args.dim
    s = round((-c + math.sqrt(c**2 + 8 * nc / (ef * c))) / 2)
    # Relative error (in %) of the connection count caused by rounding s.
    diff = 100 * (ef * s * c * (s + c) / 2) / nc - 100
    # Only near-exact solutions are reported; presumably this highlights good
    # candidates when sweeping several dims -- TODO confirm intent.
    if abs(diff) < 0.06:
        print(f"{(args.dim, s)} : connection diff by round = {diff} %")

    args.dim_token = s
    args.prod_dim = s * c
    args.max_dim = s * c * ef
    args.num_params = (s**2 + c**2) * ef / 2


def _fix_from_patch_grid(args):
    """fix == 3: derive ``expansion_factor`` and ``dim_token`` from the patch grid.

    Inputs: dim (c), dim_ppfc, num_connections (Omega), patch_size, img_size.
        s0 = (img_size // patch_size)**2,  m = dim_ppfc * s0,  s = m / c
        gamma = 2 * Omega / ((s + c) * m)
    ``prod_dim`` is overwritten with m, so its incoming value is not used.
    """
    assert args.dim_ppfc > 0
    assert args.dim > 0
    assert args.num_connections > 0
    assert (args.img_size % args.patch_size) == 0
    s0 = (args.img_size // args.patch_size) ** 2
    m = args.dim_ppfc * s0
    args.prod_dim = m
    c = args.dim
    s = m / c
    if m % c != 0:
        print(f"Warning: shape missmatch can happen: m={m}, c={c}, s={s}")
    ef = 2 * args.num_connections / (m * (s + c))
    args.expansion_factor = ef
    args.dim_token = round(s)
    args.max_dim = round(ef * s * c)
    args.num_params = (s**2 + c**2) * ef / 2


def _fix_from_ppfc_expansion(args):
    """fix == 4: derive dim, dim_token and expansion_factor from the patch size.

    Inputs: patch_size (p), img_size, num_connections (Omega),
    expansion_factor_ppfc (gamma_0).
        c = 3 * p**2 * gamma_0,  s = (img_size / patch_size)**2,  m = c * s
        gamma = 2 * Omega / (m * (c + s))
    The factor 3 is presumably the number of RGB input channels -- confirm.
    """
    assert args.expansion_factor_ppfc > 0
    c = 3 * (args.patch_size**2) * args.expansion_factor_ppfc
    s = (args.img_size / args.patch_size) ** 2
    pd = c * s
    ef = 2 * args.num_connections / (pd * (c + s))
    args.dim = round(c)
    args.dim_token = round(s)
    args.prod_dim = round(pd)
    args.expansion_factor = ef
    args.max_dim = round(ef * pd)
    args.num_params = ef * (c**2 + s**2) / 2
    args.dim_ppfc = round(c)


def configure_connections(args):
    """Derive Mixer size hyper-parameters on ``args`` in place (only for Mixers).

    ``args`` is the namespace produced by ``main``'s argument parser (any object
    with the same attributes works).  Nothing is returned.

    Modes selected by ``args.fix``:
      * -1: connections not fixed; compute max_dim, prod_dim,
            num_connections and num_params from dim and dim_token.
      *  2: num_connections fixed; solve for dim_token given dim.
      *  3: derive expansion_factor and dim_token from dim, dim_ppfc,
            num_connections, patch_size and img_size.
      *  4: derive dim, dim_token and expansion_factor from patch_size,
            img_size, num_connections and expansion_factor_ppfc.
    Any other value leaves ``args`` untouched.

    Raises:
        AssertionError: if the inputs a mode requires are not positive/valid.
    """
    if args.fix == -1:
        _configure_unfixed(args)
    elif args.fix == 2:
        _fix_num_connections(args)
    elif args.fix == 3:
        _fix_from_patch_grid(args)
    elif args.fix == 4:
        _fix_from_ppfc_expansion(args)
        

def _sweep_and_report(args, attr, values, describe):
    """Configure ``args`` for each candidate of ``attr`` and print one line per value.

    A first pass finds the value giving the largest ``args.max_dim`` (the first
    one wins on ties; nothing is tagged if no value yields a positive max_dim).
    A second pass re-configures each value in order and prints its summary,
    tagging the best one with " : max".  ``describe(value, args)`` returns the
    label printed after "<--".  ``args`` is left configured for the last value.
    """
    best_max_dim = 0
    best_value = None
    for value in values:
        setattr(args, attr, value)
        configure_connections(args)
        if best_max_dim < args.max_dim:
            best_max_dim = args.max_dim
            best_value = value

    for value in values:
        setattr(args, attr, value)
        configure_connections(args)
        suffix = " : max" if value == best_value else ""
        print(f"max_dim:{args.max_dim} <-- {describe(value, args)}, "
              f"S:{args.dim_token}, ef:{args.expansion_factor}{suffix}")


def main(dim=None, patch=None):
    """Parse the command line, run ``configure_connections`` and print the result.

    The command-line arguments (``sys.argv``) are parsed on every call.

    Args:
        dim: optional override for ``--dim``.  An int sets a single value; a
            list sweeps all values and marks the one maximizing ``max_dim``.
        patch: optional override for ``--patch_size``, handled like ``dim``.
    """
    parser = argparse.ArgumentParser(description='Fixing Connections')

    group = parser.add_argument_group('MLPMixer')
    group.add_argument('--dim', metavar='N', type=int, default=512,  # 128 for bmlp
                       help=' input dim of channelMLP(default: %(default)s)')
    group.add_argument('--dim_token', type=int, default=196,
                       help=' input dim of tokenMLP(default: %(default)s)')
    group.add_argument('--dim_ppfc', type=int, default=512,
                       help=' output dim of ppfc(default: %(default)s)')

    group.add_argument('--patch_size', type=int, default=16,
                       help=' patch size of inputs(default: %(default)s)')
    group.add_argument('--img_size', type=int, default=224,
                       help='  size of input image(default: %(default)s)')

    group.add_argument('--max_dim', type=int, default=-1,
                       help=' dim*dim_token*max(1, ef).(default: %(default)s)')
    group.add_argument('--prod_dim', type=int, default=100352,
                       help=' dim*dim_token (default: %(default)s)')

    group.add_argument('--num_connections', type=int, default=142098432,
                       help=' max:  dim**2*dim_token+ dim*dim_token**2(default: %(default)s)')
    group.add_argument('-ef', '--expansion_factor', type=float, default=4,  # 0.5
                       help=' expansion_factor  for both MLPMixer and SMixer. If < 0,  each block is FC + Activation. (default: %(default)s)')

    group.add_argument('-efpp', '--expansion_factor_ppfc', type=float, default=-1,
                       help=' expansion_factor for PPFC (default: %(default)s)')

    group.add_argument("--fix", type=int, default=-1,
                       help="-1: derive sizes from dim and dim_token, 0: do nothing, "
                            "2: fix num_connections and solve dim_token, "
                            "3: derive ef and dim_token from dim_ppfc and the patch grid, "
                            "4: derive dim, dim_token and ef from patch_size and "
                            "expansion_factor_ppfc (default: %(default)s)")

    args = parser.parse_args()

    # NOTE(review): with --fix -1 a ``dim`` override is ignored and only the
    # CLI/default dim is used -- confirm this is intended.
    if dim is None or args.fix == -1:
        configure_connections(args)
        print(args)
    elif isinstance(dim, int):
        args.dim = dim
        configure_connections(args)
    elif isinstance(dim, list):
        _sweep_and_report(args, "dim", dim, lambda d, _args: f"C:{d}")

    if patch is None:
        # Everything needed was configured (and printed, when dim is None)
        # above; configuring/printing again would only duplicate the output.
        return
    if isinstance(patch, int):
        args.patch_size = patch
        configure_connections(args)
        print(args)
    elif isinstance(patch, list):
        # fix == 4 recomputes dim from patch_size, so C is read after configuring.
        _sweep_and_report(args, "patch_size", patch,
                          lambda p, cfg: f"p:{p}, C:{cfg.dim}")
    
if __name__ == "__main__":
    # Script entry point: parse the command line and print the configuration.
    main()
